import shutil
import tempfile
import pybedtools
from numpy import *

from rpy2 import robjects, rinterface
import rpy2.robjects.numpy2ri
robjects.numpy2ri.activate()
from rpy2.robjects.packages import importr
# Load the DESeq2 R package through rpy2 and report the version in use.
deseq = importr('DESeq2')
print("Using DESeq2 version %s" % deseq.__version__)


# Timepoints of the experiment; they appear with an "hr" label in the CAGE
# library names and in the output column names.
timepoints = (0, 1, 4, 12, 24, 96)

# Read the expression table.  The header line is "enhancer" followed by one
# library name per column; every following line holds an enhancer name
# ("<dataset>|<locus>") followed by one "forward,reverse" count pair per
# library.  The two strand counts are summed into a single count per library.
filename = "enhancers.expression.txt"
print("Reading", filename)
counts = []
enhancer_names = []
# The context manager closes the file even if a malformed line raises.
with open(filename) as handle:
    line = next(handle)
    words = line.split()
    assert words[0] == 'enhancer'
    libraries = words[1:]
    n = len(libraries)
    # Data lines start at line 2 of the file (line 1 is the header).
    for line_number, line in enumerate(handle, start=2):
        words = line.split()
        # A short row would otherwise be silently padded with zero counts,
        # and a long row would fail with an uninformative IndexError.
        if len(words) != n + 1:
            raise ValueError("%s, line %d: expected %d columns, found %d"
                             % (filename, line_number, n + 1, len(words)))
        enhancer_name = words[0]
        row = zeros(n)
        for i, word in enumerate(words[1:]):
            forward, reverse = word.split(",")
            row[i] = int(forward) + int(reverse)
        # Unpacking also verifies that the name contains exactly one "|".
        dataset, locus = enhancer_name.split("|")
        assert dataset in ("FANTOM5", "HiSeq", "CAGE", "Both")
        counts.append(row)
        enhancer_names.append(enhancer_name)

# Rows are enhancers, columns are libraries (in header order).
counts = array(counts)
# indices[dataset][k] collects the positions in `libraries` (and therefore
# the columns of `counts`) of all libraries of that dataset measured at
# timepoints[k].
indices = {"CAGE": [[] for _ in timepoints],
           "HiSeq": [[] for _ in timepoints],
          }
for index, library in enumerate(libraries):
    terms = library.split("_")
    dataset = terms[0]
    if dataset == "CAGE":
        # Expected form: CAGE_<timepoint>_hr_<replicate letter>
        assert len(terms) == 4
        timepoint = int(terms[1])
        assert terms[2] == "hr"
        replicate = terms[3]
        # Compare against the individual letters: a plain `in "ABCDEFGH"`
        # is a substring test and would also accept "" or "AB".
        assert replicate in tuple("ABCDEFGH")
    elif dataset == "HiSeq":
        # Expected form: HiSeq_t<timepoint>_r<replicate number>
        assert len(terms) == 3
        timepoint = terms[1]
        assert timepoint.startswith("t")
        timepoint = int(timepoint[1:])
        replicate = terms[2]
        assert replicate in ("r1", "r2", "r3")
        # Sample negative control library prepared with water instead of RNA,
        # should not be present in the expression table
        assert not (timepoint == 1 and replicate == "r3")
    else:
        raise Exception("Unknown dataset %s" % dataset)
    # Raises ValueError if the library's timepoint is not one of `timepoints`.
    i = timepoints.index(timepoint)
    indices[dataset][i].append(index)

# Result tables with one row per enhancer.  There is one column per
# timepoint, plus a last column that is filled in later by the analysis
# over all timepoints together.
table_shape = (len(enhancer_names), len(timepoints) + 1)
basemeans = zeros(table_shape)
pvalues = zeros(table_shape)
log2fcs = zeros(table_shape)

# R functions that are looked up by name in the R global environment.
estimateSizeFactors = robjects.r['estimateSizeFactors']
results = robjects.r['results']

# Column names that the DESeq2 results table must have, in this order.
expected_names = ('baseMean', 'log2FoldChange', 'lfcSE',
                  'stat', 'pvalue', 'padj')

# For each timepoint separately, compare the CAGE libraries against the
# HiSeq libraries with a likelihood-ratio test of "~ dataset" against the
# intercept-only model "~ 1".
reduced = robjects.Formula("~ 1")
for i, timepoint in enumerate(timepoints):
    cage_columns = indices['CAGE'][i]
    hiseq_columns = indices['HiSeq'][i]
    labels = ['CAGE'] * len(cage_columns) + ['HiSeq'] * len(hiseq_columns)
    # Column order of the count matrix matches the order of `labels`.
    selected = array(cage_columns + hiseq_columns)
    coldata = robjects.DataFrame({'dataset': robjects.StrVector(labels)})
    design = robjects.Formula("~ dataset")
    dds = deseq.DESeqDataSetFromMatrix(countData=counts[:, selected],
                                       colData=coldata,
                                       design=design)
    dds = estimateSizeFactors(dds)
    dds = deseq.DESeq(dds, fitType="glmGamPoi", test="LRT", reduced=reduced)
    res = results(dds)
    table = res.do_slot('listData')
    names = table.names
    # Make sure the columns are where the positional lookups below expect.
    assert len(names) == len(expected_names)
    for position, expected in enumerate(expected_names):
        assert names[position] == expected
    basemeans[:, i] = array(table[0])
    log2fcs[:, i] = array(table[1])
    # NOTE(review): column 5 is 'padj' (checked above), so `pvalues` holds
    # the adjusted p-values, not the raw 'pvalue' column -- confirm that
    # this is intended.
    pvalues[:, i] = array(table[5])

# Analysis over all timepoints together: test for a dataset effect
# (CAGE versus HiSeq) with the timepoint included in both the full and
# the reduced model.
conditions_dataset = []
conditions_timepoint = []
jj = []

# Walk through the timepoints in order, adding for each one first the CAGE
# and then the HiSeq libraries, so that the labels and the selected count
# columns stay aligned.
for timepoint, cage_columns, hiseq_columns in zip(timepoints,
                                                  indices['CAGE'],
                                                  indices['HiSeq']):
    label = str(timepoint)
    for dataset, members in (('CAGE', cage_columns), ('HiSeq', hiseq_columns)):
        conditions_dataset += [dataset] * len(members)
        conditions_timepoint += [label] * len(members)
        jj += members

jj = array(jj)

# Timepoints are passed as strings (an R character vector), not as numbers.
dataframe = robjects.DataFrame({
    'dataset': robjects.StrVector(conditions_dataset),
    'timepoint': robjects.StrVector(conditions_timepoint),
})
design = robjects.Formula("~ timepoint + dataset")
reduced = robjects.Formula("~ timepoint")
dds = deseq.DESeqDataSetFromMatrix(countData=counts[:, jj],
                                   colData=dataframe,
                                   design=design)
dds = estimateSizeFactors(dds)
dds = deseq.DESeq(dds, fitType="glmGamPoi", test="LRT", reduced=reduced)
res = results(dds)
output = res.do_slot('listData')
names = output.names
# Make sure the columns are where the positional lookups below expect.
required = ('baseMean', 'log2FoldChange', 'lfcSE', 'stat', 'pvalue', 'padj')
assert len(names) == len(required)
for position, expected in enumerate(required):
    assert names[position] == expected
# The last column of each table holds the all-timepoints result.
basemeans[:, -1] = array(output[0])
log2fcs[:, -1] = array(output[1])
# NOTE(review): column 5 is 'padj' (checked above), so `pvalues` holds the
# adjusted p-values, not the raw 'pvalue' column -- confirm that this is
# intended.
pvalues[:, -1] = array(output[5])

# Replace missing values by neutral ones: no fold change, and a
# p-value of 1.
log2fcs[isnan(log2fcs)] = 0
pvalues[isnan(pvalues)] = 1

# Write a tab-delimited table with one row per enhancer, and three columns
# (base mean, log2 fold change, p-value) for each timepoint followed by
# three columns for the analysis over all timepoints.
filename = "enhancers.deseq.txt"
print("Writing", filename)
# The context manager closes (and flushes) the file even if formatting one
# of the rows raises.
with open(filename, 'w') as handle:
    header = ["enhancer"]
    for timepoint in timepoints:
        header.append("%02dhr_basemean" % timepoint)
        header.append("%02dhr_log2fc" % timepoint)
        header.append("%02dhr_pvalue" % timepoint)
    header.extend(["all_basemean", "all_log2fc", "all_pvalue"])
    handle.write("\t".join(header) + "\n")
    for enhancer_name, basemean, log2fc, pvalue in zip(enhancer_names, basemeans, log2fcs, pvalues):
        fields = [enhancer_name]
        # Interleave the three statistics column by column, matching the
        # order of the header.
        for b, l, p in zip(basemean, log2fc, pvalue):
            fields.append("%g" % b)
            fields.append("%g" % l)
            fields.append("%g" % p)
        handle.write("\t".join(fields) + "\n")
